import os
from typing import Callable, List, Optional, Sequence, Tuple

import jax.numpy as jnp
import numpy as np
from tqdm import tqdm

from utils import generate_random_grid, pad_grid, apply_transformation
from v0_programs import *


def generate_dataset(
    B: int,
    N: int,
    grid_size: int,
    colors: int,
    transformations: Optional[Sequence[Callable]] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate a dataset of transformations.

    For every batch a single transformation is drawn at random (via
    np.random.randint) and applied to N freshly generated random input grids,
    so all pairs within one batch share the same transformation.

    Args:
        B: Batch size (number of batches; one transformation is drawn per batch)
        N: Number of samples (input-output pairs) per batch
        grid_size: Size of each grid, forwarded to generate_random_grid
        colors: Number of colors in the grids, forwarded to generate_random_grid
        transformations: Candidate transformations to draw from. When None, the
            module-level TRANSFORMATIONS is used, which preserves the original
            calling convention.

    Returns:
        data: Numpy int32 array of shape (B, N, 30, 30, 2); channel 0 holds the
            padded input grid and channel 1 the padded output grid
        grid_shapes: Numpy int32 array of shape (B, N, 2, 2); the last axis
            selects input (0) or output (1), i.e. grid_shapes[b, n, :, 0] is
            the input grid's shape and grid_shapes[b, n, :, 1] the output's

    Raises:
        ValueError: If B > 0 and there is no transformation to draw from.
    """
    if transformations is None:
        # Backward-compatible fallback: older callers define TRANSFORMATIONS at
        # module level instead of passing the candidates explicitly.
        transformations = TRANSFORMATIONS
    if B > 0 and len(transformations) == 0:
        raise ValueError("generate_dataset needs at least one transformation to sample from")

    max_size = 30  # side length every grid is padded to
    data = np.zeros((B, N, max_size, max_size, 2), dtype=np.int32)
    grid_shapes = np.full((B, N, 2, 2), grid_size, dtype=np.int32)

    for b in tqdm(range(B), desc="Generating batches"):
        # One draw per batch, made before any grid is generated, so the order of
        # calls on the global NumPy RNG (and thus seeded output) is unchanged.
        transformation = transformations[np.random.randint(len(transformations))]

        for n in range(N):
            input_grid = generate_random_grid(grid_size, colors)
            output_grid = apply_transformation(input_grid, transformation)
            data[b, n, :, :, 0] = pad_grid(input_grid, target_shape=(max_size, max_size))
            data[b, n, :, :, 1] = pad_grid(output_grid, target_shape=(max_size, max_size))
            # Stacking on the last axis puts input/output on the final dimension,
            # mirroring the channel layout of `data`.
            grid_shapes[b, n] = np.stack([input_grid.shape, output_grid.shape], axis=-1)
    return data, grid_shapes


if __name__ == "__main__":

    # Creates a dataset of ARC-style tasks at one fixed grid size. Each task's
    # transformation is drawn from a specified set of v0 programs, and the input
    # grids are sampled from a random grid generator.

    # List of all transformation functions to use in generation.
    # generate_dataset reads this module-level name.
    TRANSFORMATIONS = [rotate_90, rotate_180, rotate_270]

    # Output directories for dataset version 0
    train_dir = "src/datasets/v0_train"
    test_dir = "src/datasets/v0_test"
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    # Set the parameters for the dataset
    B_train = 100000  # Number of training examples
    B_test = 1000  # Number of testing examples
    N = 3  # Number of samples per batch, set to 1 for individual examples
    grid_size = 5
    colors = 10
    seed = 42

    print("Starting dataset generation...")

    # Seed NumPy's global RNG once; the test split continues the same random
    # stream after the training split, so the generation order matters.
    np.random.seed(seed)
    # Generate the training dataset
    dataset_train, grid_shapes_train = generate_dataset(B_train, N, grid_size, colors)
    # Generate the testing dataset
    dataset_test, grid_shapes_test = generate_dataset(B_test, N, grid_size, colors)

    # Save the datasets, cast to uint8 before writing
    np.save(os.path.join(train_dir, "grids.npy"), dataset_train.astype(np.uint8))
    np.save(os.path.join(train_dir, "shapes.npy"), grid_shapes_train.astype(np.uint8))
    np.save(os.path.join(test_dir, "grids.npy"), dataset_test.astype(np.uint8))
    np.save(os.path.join(test_dir, "shapes.npy"), grid_shapes_test.astype(np.uint8))

    # Report the directories that were actually written
    print(f"Datasets and grid shapes saved to '{train_dir}' and '{test_dir}'.")
    print("Train dataset of shape:", dataset_train.shape)
    print("Train shapes of shape:", grid_shapes_train.shape)
    print("Test dataset of shape:", dataset_test.shape)
    print("Test shapes of shape:", grid_shapes_test.shape)
